{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Deep Learning Models -- A collection of various deep learning architectures, models, and tips for TensorFlow and PyTorch in Jupyter Notebooks.\n",
    "- Author: Sebastian Raschka\n",
    "- GitHub Repository: https://github.com/rasbt/deeplearning-models"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Sebastian Raschka \n",
      "\n",
      "CPython 3.6.8\n",
      "IPython 7.2.0\n",
      "\n",
      "torch 1.0.0\n"
     ]
    }
   ],
   "source": [
    "%load_ext watermark\n",
    "%watermark -a 'Sebastian Raschka' -v -p torch"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "- Runs on CPU or GPU (if available)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Model Zoo -- Convolutional Neural Network with He Initialization"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Imports"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import time\n",
    "import numpy as np\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "from torchvision import datasets\n",
    "from torchvision import transforms\n",
    "from torch.utils.data import DataLoader\n",
    "\n",
    "\n",
    "if torch.cuda.is_available():\n",
    "    torch.backends.cudnn.deterministic = True"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Settings and Dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Image batch dimensions: torch.Size([128, 1, 28, 28])\n",
      "Image label dimensions: torch.Size([128])\n"
     ]
    }
   ],
   "source": [
    "##########################\n",
    "### SETTINGS\n",
    "##########################\n",
    "\n",
    "# Device\n",
    "device = torch.device(\"cuda:2\" if torch.cuda.is_available() else \"cpu\")\n",
    "\n",
    "# Hyperparameters\n",
    "random_seed = 1\n",
    "learning_rate = 0.05\n",
    "num_epochs = 10\n",
    "batch_size = 128\n",
    "\n",
    "# Architecture\n",
    "num_classes = 10\n",
    "\n",
    "\n",
    "##########################\n",
    "### MNIST DATASET\n",
    "##########################\n",
    "\n",
    "# Note transforms.ToTensor() scales input images\n",
    "# to 0-1 range\n",
    "train_dataset = datasets.MNIST(root='data', \n",
    "                               train=True, \n",
    "                               transform=transforms.ToTensor(),\n",
    "                               download=True)\n",
    "\n",
    "test_dataset = datasets.MNIST(root='data', \n",
    "                              train=False, \n",
    "                              transform=transforms.ToTensor())\n",
    "\n",
    "\n",
    "train_loader = DataLoader(dataset=train_dataset, \n",
    "                          batch_size=batch_size, \n",
    "                          shuffle=True)\n",
    "\n",
    "test_loader = DataLoader(dataset=test_dataset, \n",
    "                         batch_size=batch_size, \n",
    "                         shuffle=False)\n",
    "\n",
    "# Checking the dataset\n",
    "for images, labels in train_loader:  \n",
    "    print('Image batch dimensions:', images.shape)\n",
    "    print('Image label dimensions:', labels.shape)\n",
    "    break"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "##########################\n",
    "### MODEL\n",
    "##########################\n",
    "\n",
    "\n",
    "class ConvNet(torch.nn.Module):\n",
    "\n",
    "    def __init__(self, num_classes):\n",
    "        super(ConvNet, self).__init__()\n",
    "        \n",
    "        # calculate same padding:\n",
    "        # (w - k + 2*p)/s + 1 = o\n",
    "        # => p = (s(o-1) - w + k)/2\n",
    "        \n",
    "        # 28x28x1 => 28x28x4\n",
    "        self.conv_1 = torch.nn.Conv2d(in_channels=1,\n",
    "                                      out_channels=4,\n",
    "                                      kernel_size=(3, 3),\n",
    "                                      stride=(1, 1),\n",
    "                                      padding=1) # (1(28-1) - 28 + 3) / 2 = 1\n",
    "        # 28x28x4 => 14x14x4\n",
    "        self.pool_1 = torch.nn.MaxPool2d(kernel_size=(2, 2),\n",
    "                                         stride=(2, 2),\n",
    "                                         padding=0) # (2(14-1) - 28 + 2) = 0                                       \n",
    "        # 14x14x4 => 14x14x8\n",
    "        self.conv_2 = torch.nn.Conv2d(in_channels=4,\n",
    "                                      out_channels=8,\n",
    "                                      kernel_size=(3, 3),\n",
    "                                      stride=(1, 1),\n",
    "                                      padding=1) # (1(14-1) - 14 + 3) / 2 = 1                 \n",
    "        # 14x14x8 => 7x7x8                             \n",
    "        self.pool_2 = torch.nn.MaxPool2d(kernel_size=(2, 2),\n",
    "                                         stride=(2, 2),\n",
    "                                         padding=0) # (2(7-1) - 14 + 2) = 0\n",
    "        \n",
    "        self.linear_1 = torch.nn.Linear(7*7*8, num_classes)\n",
    "        \n",
    "        ###############################################\n",
    "        # Reinitialize weights using He initialization\n",
    "        ###############################################\n",
    "        for m in self.modules():\n",
    "            if isinstance(m, torch.nn.Conv2d):\n",
    "                nn.init.kaiming_normal_(m.weight.detach())\n",
    "                m.bias.detach().zero_()\n",
    "            elif isinstance(m, torch.nn.Linear):\n",
    "                nn.init.kaiming_normal_(m.weight.detach())\n",
    "                m.bias.detach().zero_()\n",
    "        \n",
    "    def forward(self, x):\n",
    "        out = self.conv_1(x)\n",
    "        out = F.relu(out)\n",
    "        out = self.pool_1(out)\n",
    "\n",
    "        out = self.conv_2(out)\n",
    "        out = F.relu(out)\n",
    "        out = self.pool_2(out)\n",
    "        \n",
    "        logits = self.linear_1(out.view(-1, 7*7*8))\n",
    "        probas = F.softmax(logits, dim=1)\n",
    "        return logits, probas\n",
    "\n",
    "    \n",
    "torch.manual_seed(random_seed)\n",
    "model = ConvNet(num_classes=num_classes)\n",
    "\n",
    "model = model.to(device)\n",
    "\n",
    "optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)  "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 001/010 | Batch 000/469 | Cost: 2.4577\n",
      "Epoch: 001/010 | Batch 050/469 | Cost: 1.1068\n",
      "Epoch: 001/010 | Batch 100/469 | Cost: 0.6610\n",
      "Epoch: 001/010 | Batch 150/469 | Cost: 0.5354\n",
      "Epoch: 001/010 | Batch 200/469 | Cost: 0.4479\n",
      "Epoch: 001/010 | Batch 250/469 | Cost: 0.3158\n",
      "Epoch: 001/010 | Batch 300/469 | Cost: 0.4542\n",
      "Epoch: 001/010 | Batch 350/469 | Cost: 0.4278\n",
      "Epoch: 001/010 | Batch 400/469 | Cost: 0.1387\n",
      "Epoch: 001/010 | Batch 450/469 | Cost: 0.1410\n",
      "Epoch: 001/010 training accuracy: 91.97%\n",
      "Time elapsed: 0.23 min\n",
      "Epoch: 002/010 | Batch 000/469 | Cost: 0.2198\n",
      "Epoch: 002/010 | Batch 050/469 | Cost: 0.1464\n",
      "Epoch: 002/010 | Batch 100/469 | Cost: 0.2629\n",
      "Epoch: 002/010 | Batch 150/469 | Cost: 0.1920\n",
      "Epoch: 002/010 | Batch 200/469 | Cost: 0.1485\n",
      "Epoch: 002/010 | Batch 250/469 | Cost: 0.1229\n",
      "Epoch: 002/010 | Batch 300/469 | Cost: 0.1591\n",
      "Epoch: 002/010 | Batch 350/469 | Cost: 0.1411\n",
      "Epoch: 002/010 | Batch 400/469 | Cost: 0.1404\n",
      "Epoch: 002/010 | Batch 450/469 | Cost: 0.1211\n",
      "Epoch: 002/010 training accuracy: 95.21%\n",
      "Time elapsed: 0.46 min\n",
      "Epoch: 003/010 | Batch 000/469 | Cost: 0.1289\n",
      "Epoch: 003/010 | Batch 050/469 | Cost: 0.2468\n",
      "Epoch: 003/010 | Batch 100/469 | Cost: 0.1308\n",
      "Epoch: 003/010 | Batch 150/469 | Cost: 0.1887\n",
      "Epoch: 003/010 | Batch 200/469 | Cost: 0.1053\n",
      "Epoch: 003/010 | Batch 250/469 | Cost: 0.1564\n",
      "Epoch: 003/010 | Batch 300/469 | Cost: 0.1235\n",
      "Epoch: 003/010 | Batch 350/469 | Cost: 0.1388\n",
      "Epoch: 003/010 | Batch 400/469 | Cost: 0.1556\n",
      "Epoch: 003/010 | Batch 450/469 | Cost: 0.1658\n",
      "Epoch: 003/010 training accuracy: 96.45%\n",
      "Time elapsed: 0.69 min\n",
      "Epoch: 004/010 | Batch 000/469 | Cost: 0.1827\n",
      "Epoch: 004/010 | Batch 050/469 | Cost: 0.0613\n",
      "Epoch: 004/010 | Batch 100/469 | Cost: 0.1967\n",
      "Epoch: 004/010 | Batch 150/469 | Cost: 0.1072\n",
      "Epoch: 004/010 | Batch 200/469 | Cost: 0.1063\n",
      "Epoch: 004/010 | Batch 250/469 | Cost: 0.0970\n",
      "Epoch: 004/010 | Batch 300/469 | Cost: 0.0593\n",
      "Epoch: 004/010 | Batch 350/469 | Cost: 0.1031\n",
      "Epoch: 004/010 | Batch 400/469 | Cost: 0.1503\n",
      "Epoch: 004/010 | Batch 450/469 | Cost: 0.1611\n",
      "Epoch: 004/010 training accuracy: 96.62%\n",
      "Time elapsed: 0.92 min\n",
      "Epoch: 005/010 | Batch 000/469 | Cost: 0.0469\n",
      "Epoch: 005/010 | Batch 050/469 | Cost: 0.0351\n",
      "Epoch: 005/010 | Batch 100/469 | Cost: 0.1232\n",
      "Epoch: 005/010 | Batch 150/469 | Cost: 0.0432\n",
      "Epoch: 005/010 | Batch 200/469 | Cost: 0.1049\n",
      "Epoch: 005/010 | Batch 250/469 | Cost: 0.1136\n",
      "Epoch: 005/010 | Batch 300/469 | Cost: 0.2226\n",
      "Epoch: 005/010 | Batch 350/469 | Cost: 0.1271\n",
      "Epoch: 005/010 | Batch 400/469 | Cost: 0.1405\n",
      "Epoch: 005/010 | Batch 450/469 | Cost: 0.0651\n",
      "Epoch: 005/010 training accuracy: 97.22%\n",
      "Time elapsed: 1.15 min\n",
      "Epoch: 006/010 | Batch 000/469 | Cost: 0.0886\n",
      "Epoch: 006/010 | Batch 050/469 | Cost: 0.1358\n",
      "Epoch: 006/010 | Batch 100/469 | Cost: 0.1083\n",
      "Epoch: 006/010 | Batch 150/469 | Cost: 0.0799\n",
      "Epoch: 006/010 | Batch 200/469 | Cost: 0.0815\n",
      "Epoch: 006/010 | Batch 250/469 | Cost: 0.1873\n",
      "Epoch: 006/010 | Batch 300/469 | Cost: 0.1785\n",
      "Epoch: 006/010 | Batch 350/469 | Cost: 0.1107\n",
      "Epoch: 006/010 | Batch 400/469 | Cost: 0.1059\n",
      "Epoch: 006/010 | Batch 450/469 | Cost: 0.0741\n",
      "Epoch: 006/010 training accuracy: 97.22%\n",
      "Time elapsed: 1.38 min\n",
      "Epoch: 007/010 | Batch 000/469 | Cost: 0.1303\n",
      "Epoch: 007/010 | Batch 050/469 | Cost: 0.0944\n",
      "Epoch: 007/010 | Batch 100/469 | Cost: 0.0867\n",
      "Epoch: 007/010 | Batch 150/469 | Cost: 0.1706\n",
      "Epoch: 007/010 | Batch 200/469 | Cost: 0.0840\n",
      "Epoch: 007/010 | Batch 250/469 | Cost: 0.0876\n",
      "Epoch: 007/010 | Batch 300/469 | Cost: 0.0565\n",
      "Epoch: 007/010 | Batch 350/469 | Cost: 0.0805\n",
      "Epoch: 007/010 | Batch 400/469 | Cost: 0.0784\n",
      "Epoch: 007/010 | Batch 450/469 | Cost: 0.1238\n",
      "Epoch: 007/010 training accuracy: 97.47%\n",
      "Time elapsed: 1.62 min\n",
      "Epoch: 008/010 | Batch 000/469 | Cost: 0.0740\n",
      "Epoch: 008/010 | Batch 050/469 | Cost: 0.0674\n",
      "Epoch: 008/010 | Batch 100/469 | Cost: 0.1884\n",
      "Epoch: 008/010 | Batch 150/469 | Cost: 0.0757\n",
      "Epoch: 008/010 | Batch 200/469 | Cost: 0.0633\n",
      "Epoch: 008/010 | Batch 250/469 | Cost: 0.1166\n",
      "Epoch: 008/010 | Batch 300/469 | Cost: 0.0309\n",
      "Epoch: 008/010 | Batch 350/469 | Cost: 0.0821\n",
      "Epoch: 008/010 | Batch 400/469 | Cost: 0.1253\n",
      "Epoch: 008/010 | Batch 450/469 | Cost: 0.0486\n",
      "Epoch: 008/010 training accuracy: 97.53%\n",
      "Time elapsed: 1.85 min\n",
      "Epoch: 009/010 | Batch 000/469 | Cost: 0.0538\n",
      "Epoch: 009/010 | Batch 050/469 | Cost: 0.1860\n",
      "Epoch: 009/010 | Batch 100/469 | Cost: 0.0645\n",
      "Epoch: 009/010 | Batch 150/469 | Cost: 0.0392\n",
      "Epoch: 009/010 | Batch 200/469 | Cost: 0.0662\n",
      "Epoch: 009/010 | Batch 250/469 | Cost: 0.0885\n",
      "Epoch: 009/010 | Batch 300/469 | Cost: 0.1958\n",
      "Epoch: 009/010 | Batch 350/469 | Cost: 0.0716\n",
      "Epoch: 009/010 | Batch 400/469 | Cost: 0.0790\n",
      "Epoch: 009/010 | Batch 450/469 | Cost: 0.0223\n",
      "Epoch: 009/010 training accuracy: 97.89%\n",
      "Time elapsed: 2.08 min\n",
      "Epoch: 010/010 | Batch 000/469 | Cost: 0.0982\n",
      "Epoch: 010/010 | Batch 050/469 | Cost: 0.0772\n",
      "Epoch: 010/010 | Batch 100/469 | Cost: 0.1971\n",
      "Epoch: 010/010 | Batch 150/469 | Cost: 0.0399\n",
      "Epoch: 010/010 | Batch 200/469 | Cost: 0.0341\n",
      "Epoch: 010/010 | Batch 250/469 | Cost: 0.0538\n",
      "Epoch: 010/010 | Batch 300/469 | Cost: 0.1165\n",
      "Epoch: 010/010 | Batch 350/469 | Cost: 0.1016\n",
      "Epoch: 010/010 | Batch 400/469 | Cost: 0.1560\n",
      "Epoch: 010/010 | Batch 450/469 | Cost: 0.1757\n",
      "Epoch: 010/010 training accuracy: 97.80%\n",
      "Time elapsed: 2.31 min\n",
      "Total Training Time: 2.31 min\n"
     ]
    }
   ],
   "source": [
    "def compute_accuracy(model, data_loader):\n",
    "    correct_pred, num_examples = 0, 0\n",
    "    for features, targets in data_loader:\n",
    "        features = features.to(device)\n",
    "        targets = targets.to(device)\n",
    "        logits, probas = model(features)\n",
    "        _, predicted_labels = torch.max(probas, 1)\n",
    "        num_examples += targets.size(0)\n",
    "        correct_pred += (predicted_labels == targets).sum()\n",
    "    return correct_pred.float()/num_examples * 100\n",
    "    \n",
    "\n",
    "start_time = time.time()\n",
    "for epoch in range(num_epochs):\n",
    "    model = model.train()\n",
    "    for batch_idx, (features, targets) in enumerate(train_loader):\n",
    "        \n",
    "        features = features.to(device)\n",
    "        targets = targets.to(device)\n",
    "\n",
    "        ### FORWARD AND BACK PROP\n",
    "        logits, probas = model(features)\n",
    "        cost = F.cross_entropy(logits, targets)\n",
    "        optimizer.zero_grad()\n",
    "        \n",
    "        cost.backward()\n",
    "        \n",
    "        ### UPDATE MODEL PARAMETERS\n",
    "        optimizer.step()\n",
    "        \n",
    "        ### LOGGING\n",
    "        if not batch_idx % 50:\n",
    "            print ('Epoch: %03d/%03d | Batch %03d/%03d | Cost: %.4f' \n",
    "                   %(epoch+1, num_epochs, batch_idx, \n",
    "                     len(train_loader), cost))\n",
    "    \n",
    "    model = model.eval()\n",
    "    print('Epoch: %03d/%03d training accuracy: %.2f%%' % (\n",
    "          epoch+1, num_epochs, \n",
    "          compute_accuracy(model, train_loader)))\n",
    "    \n",
    "    print('Time elapsed: %.2f min' % ((time.time() - start_time)/60))\n",
    "    \n",
    "print('Total Training Time: %.2f min' % ((time.time() - start_time)/60))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Evaluation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Test accuracy: 97.67%\n"
     ]
    }
   ],
   "source": [
    "print('Test accuracy: %.2f%%' % (compute_accuracy(model, test_loader)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "numpy       1.15.4\n",
      "torch       1.0.0\n",
      "\n"
     ]
    }
   ],
   "source": [
    "%watermark -iv"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.3"
  },
  "toc": {
   "nav_menu": {},
   "number_sections": true,
   "sideBar": true,
   "skip_h1_title": false,
   "title_cell": "Table of Contents",
   "title_sidebar": "Contents",
   "toc_cell": false,
   "toc_position": {},
   "toc_section_display": true,
   "toc_window_display": false
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
